import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys
from matplotlib.patches import Patch
import matplotlib.lines as mlines
sys.path.insert(1,'/REDACTED/fairness/code/config')
from configureColors import *
sys.path.insert(1,'/REDACTED/fairness/code/utilities')
import weightedBinscatter as wb
import estimators as et
#sys.path.append('packages')
#import binscatter
import matplotlib.font_manager as fm
import matplotlib
import statsmodels.formula.api as smf
from statsmodels.iolib.smpickle import load_pickle
import joblib
from trajectoryPlotters import *

# Plot defaults: apply the project style sheet, then register the bundled
# cmunrm.ttf font under the name 'latex' and make it the default font family.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(fname='/REDACTED/fairness/code/config/fonts/cmunrm.ttf', name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# Colour dictionaries supplied by configureColors (star-imported above).
colorsBNB = cdictBlackNonBlack()
colorsENE = cdictEITCNon()

# Short aliases used by the plotting calls: c* = colour, a* = alpha.
cb, ab = colorsBNB['Black']['color'], colorsBNB['Black']['alpha']
cnb, anb = colorsBNB['non-Black']['color'], colorsBNB['non-Black']['alpha']
ce, ae = colorsENE['eitc']['color'], colorsENE['eitc']['alpha']
cne, ane = colorsENE['non']['color'], colorsENE['non']['alpha']

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
	"""Linear estimator: OLS of ``outcome`` on ``pbvarb`` with HC1 robust SEs.

	Parameters
	----------
	dataset : DataFrame containing the columns named by ``pbvarb`` and ``outcome``.
	pbvarb : name of the regressor column (predicted probability Black).
	outcome : name of the dependent-variable column.

	Returns
	-------
	(coef, se, black_aud_rate, non_black_aud_rate)
		``coef``/``se`` are the slope on ``pbvarb`` and its standard error;
		``non_black_aud_rate`` is the fitted intercept (prediction at pbvarb = 0)
		and ``black_aud_rate`` is intercept + slope (prediction at pbvarb = 1).
	"""
	# The file imports statsmodels.formula.api as `smf`; the name `sm` is not
	# imported by name here, so go through the formula API that is in scope.
	model = smf.ols(outcome + ' ~ ' + pbvarb, data = dataset).fit(cov_type = 'HC1')
	coef = model.params[pbvarb]
	se = model.bse[pbvarb]
	# Look the intercept up by label rather than by integer position: integer
	# keys on a label-indexed Series are deprecated in recent pandas.
	intercept = model.params['Intercept']
	black_aud_rate = intercept + coef
	non_black_aud_rate = intercept
	return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
	"""Probability-weighted estimator of the outcome-rate gap.

	The Black rate is the mean of ``outcome`` weighted by ``pbvarb``; the
	non-Black rate is the mean of ``outcome`` weighted by ``1 - pbvarb``.

	Returns
	-------
	(est, black_aud_rate, non_black_aud_rate) where
	``est = black_aud_rate - non_black_aud_rate``.
	"""
	p_black = dataset[pbvarb]
	p_non_black = 1 - p_black
	y = dataset[outcome]

	black_aud_rate = (p_black * y).sum() / p_black.sum()
	non_black_aud_rate = (p_non_black * y).sum() / p_non_black.sum()
	return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
	"""Return the regression SE together with its rescaled counterpart.

	``seReg`` is supplied by the caller (for example the second value returned
	by regEstimate) and is multiplied by getSEMultiplier(dataset, pbvarb) to
	obtain the second returned value. ``outcome`` is accepted but not used.

	Returns
	-------
	(seReg, seChen)
	"""
	seChen = seReg * getSEMultiplier(dataset, pbvarb)
	return seReg, seChen

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
	"""Return sqrt(Var(p) / (mean(p) * (1 - mean(p)))) for the column ``pbvarb``.

	``Var`` is pandas' default (sample) variance of the column.
	"""
	probs = dataset[pbvarb]
	avg = probs.mean()
	return np.sqrt(probs.var() / (avg * (1 - avg)))

### OUTPUT

# Prefix for every output path built as out + '<filename>'.
out = '/REDACTED/'

### OP AUDIT
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv")
# Report row counts explicitly: a bare len(...) expression shows nothing when
# the file is run as a script rather than typed into an interactive session.
print('dataBISG rows (raw):', len(dataBISG))
# Keep only rows that have a predicted probability.
dataBISG = dataBISG[dataBISG.predicted_prob_black.notna()]
print('dataBISG rows (non-missing predicted_prob_black):', len(dataBISG))
# Work on an independent copy so added columns never touch dataBISG.
data = dataBISG.copy()
# Constant weight column, passed as the weight column to the binscatter helpers.
data['unitweight'] = 1
'''
### NC data
ncdata = pd.read_csv("/REDACTED/NC_Analysis_Dataset_2014_tplevel.csv")
### NC Weights
weights_old = pd.read_csv("/REDACTED/HertzGraff/nc_reweights_v4_April.csv")
weights = pd.read_stata('/REDACTED/BIFSG/nc_rewt_2014_coded_output_v4_December_ncsamp.dta')


len(ncdata)
ncdata = ncdata[ncdata.black_ind.notna()]
len(ncdata)
len(weights)
len(weights_old)
ncdata_old = pd.merge(ncdata, weights_old, on='taxpayer_id')
ncdata = pd.merge(ncdata,weights,on='taxpayer_id')
len(ncdata)
len(ncdata_old)


# North Carolina rows of the BISG file, reduced to the merge key and the columns
# used downstream, with predicted_prob_black renamed to pprob_black.
# Selecting and renaming in one expression avoids assigning a new column onto a
# filtered slice of dataBISG (which triggers SettingWithCopyWarning).
nc_BISG = (dataBISG.loc[dataBISG['state']=="NC",
                        ['taxpayer_id', 'predicted_prob_black', 'aud_no_research_audits', 'isEIC']]
           .rename(columns={'predicted_prob_black': 'pprob_black'}))

ncdata = pd.merge(ncdata,nc_BISG,on='taxpayer_id')
ncdata_old = pd.merge(ncdata_old,nc_BISG,on='taxpayer_id')
ncdata.drop(['predicted_prob_black', 'black_prob'], axis = 1, inplace = True)
ncdata_old.drop(['predicted_prob_black'], axis = 1, inplace = True)
ncdata['predicted_prob_black'] = ncdata['pprob_black'] 
ncdata_old['predicted_prob_black'] = ncdata_old['pprob_black'] 
ncdata.drop(['pprob_black'], axis = 1, inplace = True)
ncdata_old.drop(['pprob_black'], axis = 1, inplace = True)
ncdata['p_black_rd'] = ncdata['predicted_prob_black'].round(2)
ncdata_old['p_black_rd'] = ncdata_old['predicted_prob_black'].round(2)

ncdata['black_ind_wt'] = ncdata['black_ind']*ncdata['uswgt']
ncdata_old['black_ind_wt'] = ncdata_old['black_ind']*ncdata_old['uswgt']

ncdata['isEIC_old'] = ncdata['isEIC']
ncdata_old['isEIC_old'] = ncdata_old['isEIC']
ncdata['isEIC'] = ncdata['eic_ind'] # only differs for two TP's
ncdata_old['isEIC'] = ncdata_old['eic_ind'] # only differs for two TP's

ncdata['aud_no_research_audits_old'] = ncdata['aud_no_research_audits']
ncdata_old['aud_no_research_audits_old'] = ncdata_old['aud_no_research_audits']
ncdata['aud_no_research_audits'] = [1 if (x.find('[80]') == -1)
                         and (x.find(' 80]') == -1)
                         and (x.find('[80 ') == -1)
                         and (x.find('[91]') == -1)
                         and (x.find(' 91]') == -1)
                         and (x.find('[91 ') == -1)
                         and y == 1
                         else 0
                         for x, y in zip(ncdata.audit_source_code.astype(str), ncdata.audited)] # identical to aud_no_research_audits_old

ncdata_old['aud_no_research_audits'] = [1 if (x.find('[80]') == -1)
                         and (x.find(' 80]') == -1)
                         and (x.find('[80 ') == -1)
                         and (x.find('[91]') == -1)
                         and (x.find(' 91]') == -1)
                         and (x.find('[91 ') == -1)
                         and y == 1
                         else 0
                         for x, y in zip(ncdata_old.audit_source_code.astype(str), ncdata_old.audited)] # identical to aud_no_research_audits_old


grouped1 = ncdata.groupby(['p_black_rd','isEIC']).black_ind.mean().reset_index()
grouped1_old = ncdata_old.groupby(['p_black_rd','isEIC']).black_ind.mean().reset_index()
groupedwt1 = (ncdata.groupby(['p_black_rd','isEIC']).black_ind_wt.sum()/ncdata.groupby(['p_black_rd','isEIC']).uswgt.sum()).reset_index()
groupedwt1_old = (ncdata_old.groupby(['p_black_rd','isEIC']).black_ind_wt.sum()/ncdata_old.groupby(['p_black_rd','isEIC']).uswgt.sum()).reset_index()
ovwt1 = (ncdata.groupby('p_black_rd').black_ind_wt.sum()/ncdata.groupby('p_black_rd').uswgt.sum()).reset_index()
ovwt1_old = (ncdata_old.groupby('p_black_rd').black_ind_wt.sum()/ncdata_old.groupby('p_black_rd').uswgt.sum()).reset_index()


######################
##### Figure 1 (left): Distribution and Calibration of Race Imputation
######################

# probBlack_hist_fw.png

data['pBlack_bucket'] = pd.cut(data.predicted_prob_black,bins=[i/20 for i in range(22)],labels=[i/20 for i in range(21)], right=False,include_lowest=True).astype(float)
pct_black_counts = data.groupby('pBlack_bucket').taxpayer_id.count().reset_index()
fig,ax=plt.subplots(figsize=(10,8))
plt.bar(pct_black_counts.pBlack_bucket,pct_black_counts.taxpayer_id/pct_black_counts.taxpayer_id.sum(),width=0.05)
plt.xlabel('BIFSG-Predicted Probability Black')
plt.ylabel('Population Share')
plt.savefig(out+'probBlack_hist_fw.png')
plt.close('all')


######################
##### Figure 1 (right): Distribution and Calibration of Race Imputation
######################

# calibration_pblack_01_uswgt.png

fig,ax=plt.subplots(figsize=(10,8))
plt.plot(groupedwt1[groupedwt1.isEIC==0].p_black_rd,groupedwt1[groupedwt1.isEIC==0][0],label='Non-EITC Claimants',color=cne,alpha=ane)
plt.plot(groupedwt1[groupedwt1.isEIC==1].p_black_rd,groupedwt1[groupedwt1.isEIC==1][0],label='EITC Claimants',color=ce,alpha=ae)
plt.plot(ovwt1.p_black_rd,ovwt1[0],label='Overall')
plt.plot([0,1],[0,1],linestyle='--',color='black',linewidth=3)
plt.xlabel('BIFSG-Predicted Probability Black')
plt.ylabel('True Share Black')
plt.legend()
plt.savefig(out+'calibration_pblack_01_uswgt_new.png')
plt.close('all')

# Calibration plot built from the old weights (groupedwt1_old / ovwt1_old).
fig,ax=plt.subplots(figsize=(10,8))
# Both masks must be derived from groupedwt1_old itself. The EITC mask used to be
# taken from groupedwt1, a different frame whose index need not line up.
non_eitc_old = groupedwt1_old[groupedwt1_old.isEIC==0]
eitc_old = groupedwt1_old[groupedwt1_old.isEIC==1]
plt.plot(non_eitc_old.p_black_rd,non_eitc_old[0],label='Non-EITC Claimants',color=cne,alpha=ane)
plt.plot(eitc_old.p_black_rd,eitc_old[0],label='EITC Claimants',color=ce,alpha=ae)
plt.plot(ovwt1_old.p_black_rd,ovwt1_old[0],label='Overall')
# 45-degree reference line.
plt.plot([0,1],[0,1],linestyle='--',color='black',linewidth=3)
plt.xlabel('BIFSG-Predicted Probability Black')
plt.ylabel('True Share Black')
plt.legend()
plt.savefig(out+'calibration_pblack_01_uswgt.png') # this is correct, old weights are used
plt.close('all')

#########################
#### Fig 2: Audit Rate by Predicted Race Conditional on Self-Reported Race
#########################
# auditrate_p_bs.png

ncdata['unitwt'] = 1
ncb = ncdata[ncdata.black_ind==1]
ncb = ncb.dropna(subset=['predicted_prob_black'])
ncnb = ncdata[ncdata.black_ind==0]
ncnb = ncnb.dropna(subset=['predicted_prob_black'])
fig,_ = wb.weightedBinscatterMulti([ncnb,ncb],xlab='BIFSG-Predicted Probability Black',datalabels=['Non-Black','Black'], x_col='predicted_prob_black',y_col='aud_no_research_audits',ylab='Audit Rate (%)', w_col='unitwt',nbins=100,alphas=[anb,ab],reportCoefs=False,colors=[cnb,cb], best_lin_fit=True, scale_y=True, shapes=['o', 'x'])
b_lab = mlines.Line2D([],[], markeredgecolor = cb, alpha = ab, marker = 'x', linestyle = 'None', label = 'Black')
nb_lab = mlines.Line2D([],[],markeredgecolor = cnb, markerfacecolor = 'none', alpha = anb, marker = 'o', linestyle = 'None', label = 'Non-Black')
plt.legend(handles = [b_lab, nb_lab])
plt.ylim([0, .017])
plt.arrow(0.4,0.0035,0.2,0.00048, width = 0.00005, head_length = 0.025, head_width = 0.0003, overhang = 0.5, facecolor='#ababab', edgecolor = '#ababab')
plt.arrow(0.4,0.0099,0.2,0.0015, width = 0.00005, head_length = 0.025, head_width = 0.0003, overhang = 0.5, facecolor='#ababab', edgecolor = '#ababab')
plt.arrow(0.1,0.00257,0,0.0035, width = 0.002, head_length = 0.0005, head_width = 0.015, overhang = 0.5, facecolor='#595959', edgecolor = '#595959')
plt.arrow(0.9,0.0041,0,0.0074, width = 0.002, head_length = 0.0005, head_width = 0.015, overhang = 0.5, facecolor='#595959', edgecolor = '#595959')
#plt.savefig(out+'auditrate_p_bs_new7.png')
plt.savefig(out+'auditrate_p_bs.png')
plt.close('all')
# NOTE: either one of the 'x' markers or the Black line is slightly different from before


ncdata_old['unitwt'] = 1
ncb = ncdata_old[ncdata_old.black_ind==1]
ncb = ncb.dropna(subset=['predicted_prob_black'])
ncnb = ncdata_old[ncdata_old.black_ind==0]
ncnb = ncnb.dropna(subset=['predicted_prob_black'])
fig,_ = wb.weightedBinscatterMulti([ncnb,ncb],xlab='BIFSG-Predicted Probability Black',datalabels=['Non-Black','Black'], x_col='predicted_prob_black',y_col='aud_no_research_audits',ylab='Audit Rate (%)', w_col='unitwt',nbins=100,alphas=[anb,ab],reportCoefs=False,colors=[cnb,cb], best_lin_fit=True, scale_y=True, shapes=['o', 'x'])
b_lab = mlines.Line2D([],[], markeredgecolor = cb, alpha = ab, marker = 'x', linestyle = 'None', label = 'Black')
nb_lab = mlines.Line2D([],[],markeredgecolor = cnb, markerfacecolor = 'none', alpha = anb, marker = 'o', linestyle = 'None', label = 'Non-Black')
plt.legend(handles = [b_lab, nb_lab])
plt.ylim([0, .017])
plt.arrow(0.4,0.0035,0.2,0.00048, width = 0.00005, head_length = 0.025, head_width = 0.0003, overhang = 0.5, facecolor='#ababab', edgecolor = '#ababab')
plt.arrow(0.4,0.0099,0.2,0.0015, width = 0.00005, head_length = 0.025, head_width = 0.0003, overhang = 0.5, facecolor='#ababab', edgecolor = '#ababab')
plt.arrow(0.1,0.00257,0,0.0035, width = 0.002, head_length = 0.0005, head_width = 0.015, overhang = 0.5, facecolor='#595959', edgecolor = '#595959')
plt.arrow(0.9,0.0041,0,0.0074, width = 0.002, head_length = 0.0005, head_width = 0.015, overhang = 0.5, facecolor='#595959', edgecolor = '#595959')
plt.savefig(out+'auditrate_p_bs_old.png')
plt.close('all')

#########################
#### Fig 3 (left): Estimated Audit Rate by Race
#########################

# bs_BISG_all.png
fig, dat = wb.weightedBinscatterQnD(data,xcol='predicted_prob_black',ycol='aud_no_research_audits',wcol='unitweight',nbins=100,xlabel='BIFSG-Predicted Probability Black',ylabel='Audit Rate (%)',includeCoef=False, shuff_0=True, best_lin_fit=False, set_ymin_zero=True, scale_y=True)
ax = fig.gca()
ax.set_ylim([0, 0.024])
fig.savefig(out+'bs_BISG_all.png')

#########################
#### Fig 5 (left): Estimated Audit Rates by Race and EITC Claim Status
#########################

# BISG_eitc_breakout.png
datasets = [data.loc[data.isEIC==1], data.loc[data.isEIC==0]]
fig, dat = wb.weightedBinscatterMulti(datasets, 
                                        x_col='predicted_prob_black', 
                                        y_col='aud_no_research_audits', 
                                        w_col='unitweight', 
                                        nbins=100, 
                                        xlab='Predicted Probability Black', 
                                        ylab='Audit Rate (%)',
                                        datalabels=['EITC', 'Non-EITC'],
                                        colors=[ce, cne],
                                        alphas=[ae, ane],
                                        shapes=['o', 'x'],
                                        styles=['-', ':'],
                                        xcollist=None,
                                        ycollist=None,
                                        wcollist=None,
                                        reportCoefs=False,
                                        scale_y=True,
                                        best_lin_fit=False,
                                        ylims=[0,0.05])

eitc_lab = mlines.Line2D([],[], markeredgecolor = ce, markerfacecolor = 'none', alpha = ae, marker = 'o', linestyle = 'None', label = 'EITC')
noneitc_lab = mlines.Line2D([],[],markeredgecolor = cne, alpha = ane, marker = 'x', linestyle = 'None', label = 'Non-EITC')
plt.legend(handles = [eitc_lab, noneitc_lab])

fig.savefig(out+'BISG_eitc_breakout.png')

#########################
#### Fig 6: Audit Rate Disparities by EITC Subgroup
#########################

# Weighted_Estimator_graph_proba_startwEITC.png -- probabilistic
data['unitwt'] = 1
data['predicted_prob_nonblack'] = 1 - data['predicted_prob_black']
weighted_estimates = et.makeEstimatorTable(data,out)
weighted_estimates.Population = weighted_estimates.Population.map({'All TP':'All Taxpayers','EIC':'EITC','Non-EIC':'Non-EITC','Joint EIC':'Joint EITC','Nonjoint EIC':'Single EITC','Nonjoint Male EIC':'Single Male EITC','Nonjoint Nonmale EIC':'Single Female EITC','Nonjoint Male EIC w/Deps':'Single Male EITC w/ Deps','Nonjoint Male EIC no Deps':'Single Male EITC w/o Deps'})
et.graphWeightedEstimators(weighted_estimates[~weighted_estimates.Population.isin(['All Taxpayers', 'EITC', 'Non-EITC'])],out=out,colb=cb,alpb=ab,colnb=cnb,alpnb=anb,rotation=45,suffix='_proba_startwEITC')

#########################
#### Fig A.6: Audit Rate Disparities by EITC Subgroup (Linear Estimator)
#########################
# Weighted_Estimator_graph_lin_startwEITC.png -- linear

weighted_estimates = et.makeEstimatorTable(data, out, estimator='linear')
weighted_estimates.Population = weighted_estimates.Population.map({'All TP':'All Taxpayers','EIC':'EITC','Non-EIC':'Non-EITC','Joint EIC':'Joint EITC','Nonjoint EIC':'Single EITC','Nonjoint Male EIC':'Single Male EITC','Nonjoint Nonmale EIC':'Single Female EITC','Nonjoint Male EIC w/Deps':'Single Male EITC w/ Deps','Nonjoint Male EIC no Deps':'Single Male EITC w/o Deps'})
et.graphWeightedEstimators(weighted_estimates[~weighted_estimates.Population.isin(['All Taxpayers', 'EITC', 'Non-EITC'])],out=out,colb=cb,alpb=ab,colnb=cnb,alpnb=anb,rotation=45,suffix='_lin_startwEITC')

#########################
#### Fig A.5 (right): Estimated Audit Rates by EITC Claim Status
#########################

# audit_rate_eitc.png

ar_eic = data.groupby('isEIC').aud_no_research_audits.mean().reset_index()
ar_eic = ar_eic.sort_values('isEIC', ascending=False)
ar_eic['isEIC'] = ar_eic['isEIC'].map({1:'EITC',0:'Non-EITC'})
print(ar_eic)
plt.close('all')
fig,ax=plt.subplots(figsize=(10,8))
bars = ax.bar(x=ar_eic.isEIC,height=(ar_eic.aud_no_research_audits*100),align='center',width=0.7)
plt.ylabel('Audit Rate (%)')
bars[0].set_color(ce)
bars[0].set_alpha(ae)
bars[1].set_color(cne)
bars[1].set_alpha(ane)
et.labelbar(ax,[bars[0]],2,20)
et.labelbar(ax,[bars[1]],2,20,additiveeps=-0.0001)
ax.set_xticklabels(ar_eic.isEIC,fontsize=20)
plt.savefig(out+'audit_rate_eitc.png')

#########################
#### Fig A.5 (left): Estimated Audit Rates by EITC Claim Status
#########################

# bs_EITCClaimants_all.png

fig,dat = wb.weightedBinscatterQnD(data[~data.isEIC.isna()],xcol='predicted_prob_black',ycol='isEIC',wcol='unitweight',nbins=100,xlabel='Predicted Probability Black',ylabel='EITC Claim Rate (%)',includeCoef=False, shuff_0=True, best_lin_fit=False, scale_y=True, set_ymin_zero=True, set_ymax_one=True)
fig.savefig(out+'bs_EITCClaimants_all.png')


#########################
#### Fig A.11 (old): Racial Composition of EITC Activity Codes
#########################

ac_pct_black = data.groupby('total_pos_inc_class').predicted_prob_black.mean().reset_index()
idx = np.arange(2)
ac_pct_black = ac_pct_black[ac_pct_black.total_pos_inc_class.isin([70,71])]
fig,ax=plt.subplots(figsize=(10,8))
b = ax.bar(idx,ac_pct_black.predicted_prob_black,color=cb,alpha=ab,label='Black', align='center')
nb = ax.bar(idx,1-ac_pct_black.predicted_prob_black,color=cnb,alpha=anb,bottom=ac_pct_black.predicted_prob_black,label='Non-Black', align='center')
ax.set_xticks(idx)
ax.set_xticklabels([270,271])
et.labelbar(ax,b, fontsize=18)
plt.xlabel('Activity Code')
plt.ylabel('Estimated Fraction Black')
plt.legend()
plt.savefig(out+'frac_black_by_ac_eitconly.png')


#########################
#### Fig A.14: Estimated Audit Rate Disparity by Year
#########################

## Create disparity_by_year.csv from Tom's data, "Disparity over time data Sept 2022.xlsx"

# calculate 2014 estimates and adjust disparity_by_year.csv accordingly
# Each estimate is computed once and printed: a bare expression displays nothing
# when the file is run as a script.
est_all, black_all, nonblack_all = chenEstimate(data, "predicted_prob_black", "aud_no_research_audits")
print('2014 all taxpayers (disparity, Black rate, non-Black rate):', est_all, black_all, nonblack_all)
print('2014 all taxpayers Black/non-Black ratio:', black_all / nonblack_all)

data_eic = data.loc[data.isEIC==1]
est_eic, black_eic, nonblack_eic = chenEstimate(data_eic, "predicted_prob_black", "aud_no_research_audits")
print('2014 EITC claimants (disparity, Black rate, non-Black rate):', est_eic, black_eic, nonblack_eic)
print('2014 EITC claimants Black/non-Black ratio:', black_eic / nonblack_eic)

data_disp_year = pd.read_csv("/REDACTED/disparity_by_year.csv")


### full pop disparity bar by year plot
data_all = data_disp_year[data_disp_year['Group'] == "All"]
idx = np.arange(len(data_all.Year.unique()))
bwidth = 0.2

ax = plt.subplot(111)
arnb = ax.bar(idx/2, data_all.NonBlack ,align = 'center', width = (bwidth-0.03), color = 'purple', alpha = 0.4, label = 'Non-Black', capsize = 15)
arb = ax.bar(idx/2-bwidth,data_all.Black, align = 'center', width = (bwidth-0.03), color = 'purple', alpha = 1, label = 'Black', capsize = 15)

ratio = ax.scatter(idx/2- bwidth/2, data_all.Ratio, color = "Black", label = "Ratio")

ax.set_xticks((idx/2) - bwidth/2)
ax.set_xticklabels(data_all.Year.unique().astype(int))
ax.set_xlabel('Year')
ax.set_ylabel('Audit Rate (%) and Black/Non-Black Ratio')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles)
plt.savefig('/REDACTED/fullpop_audit_disparity_year.png')
plt.close()


#########################
#### Fig A.15: Estimated Audit Rate DIsparity Among EITC Claimants by Year
#########################

### eitc pop disparity bar by year plot
data_eitc = data_disp_year[data_disp_year['Group'] == "EITC"]
idx = np.arange(len(data_eitc.Year.unique()))
bwidth = 0.2

ax = plt.subplot(111)
arnb = ax.bar(idx/2, data_eitc.NonBlack ,align = 'center', width = (bwidth-0.03), color = 'purple', alpha = 0.4, label = 'Non-Black', capsize = 15)
arb = ax.bar(idx/2-bwidth,data_eitc.Black, align = 'center', width = (bwidth-0.03), color = 'purple', alpha = 1, label = 'Black', capsize = 15)

ratio = ax.scatter(idx/2- bwidth/2, data_eitc.Ratio, color = "Black", label = "Ratio")
'''
ax.set_xticks((idx/2) - bwidth/2)
ax.set_xticklabels(data_eitc.Year.unique().astype(int))
ax.set_xlabel('Year')
ax.set_ylabel('Audit Rate (%) and Black/Non-Black Ratio')
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles)
plt.savefig('/REDACTED/eitcpop_audit_disparity_year.png')
plt.close()

